# -*- coding: utf-8 -*-


import gurobipy as gr
import numpy as np
import root_const
import pandas as pd
from root_const import ROOT_PATH
from server_placement.server_placement_basic import ServerPlacementBasic

class ServerPlacementILPBalance(ServerPlacementBasic):

    """
    Sub class of ServerPlacementBasic: places servers with an integer linear program.

    The load of every base station is known for several time parts
    (``lanti_list`` / ``long_list`` / ``frequ_list``).  The capacity constraint of
    every base station is added once per time part, so a single placement has to
    cover the load of every time part at the same time.
    """

    def __init__(self, location_array, k_near, compute_num, lanti1, long1,
                 lanti_list, long_list, frequ_list, adjust_rate):
        """

        param
        ----------
            location_array : array
                the array to store both the latitude and longitude information of the location
                structure : nx2
                content : first column ---> latitude , second column ---> longitude
            k_near : int
                the number of the near base stations we will find for every base station
            compute_num : int
                the number of the servers we want to assign
            lanti1 : list
                store all the latitude information of the base stations
            long1 : list
                store all the longitude information of the base stations
            lanti_list : list
                one entry per time part; each entry lists the latitudes of the
                locations that carry load in that time part
            long_list : list
                one entry per time part; longitudes aligned with ``lanti_list``
            frequ_list : list
                one entry per time part; loads aligned with ``lanti_list``
            adjust_rate : number
                factor multiplied into every rebuilt load value (see ``sum_distribution``);
                presumably a scalar -- TODO confirm with callers
        """
        # NOTE(review): the single-underscore ``_init_`` is kept as written.  If
        # ServerPlacementBasic only defines ``__init__`` this line raises
        # AttributeError.  The base initialiser presumably sets self.compute_num
        # (read in server_placement) -- TODO confirm against ServerPlacementBasic.
        ServerPlacementBasic._init_(self, compute_num)
        self.location_array = location_array
        self.k_near = k_near
        self.lanti1 = lanti1
        self.long1 = long1
        self.lanti_list = lanti_list
        self.long_list = long_list
        self.frequ_list = frequ_list
        self.adjust_rate = adjust_rate

    def load_distribution(self, lanti_1, long_1, lanti_2, long_2, load):
        """Place the load of the loca2 locations onto the indexing of loca1.

        A loca2 entry matches a loca1 entry when both its latitude and its
        longitude are equal.  The matched load is stored at the loca1 index.

        param
        -----------
            lanti_1 : list
                store all the latitude information of the loca1
            long_1 : list
                store all the longitude information of the loca1
            lanti_2 : list
                store all the latitude information of the loca2
            long_2 : list
                store all the longitude information of the loca2
            load : list
                store all the load information of the loca2 (aligned with lanti_2 / long_2)

        return
        ----------
            load_rebuild : list
                one value per loca1 location: the load of the matching loca2
                location, 0 where there is no match.  If loca2 repeats a location,
                the later load overwrites the earlier one.  If loca1 repeats a
                location, its first occurrence receives the load.
            num : int
                the number of loca2 entries that matched a loca1 location

        refer
        ----------
            Other func : None
        """
        # Index loca1 by its full (latitude, longitude) pair.  Looking the two
        # coordinates up separately (first index of the latitude vs. first index of
        # the longitude) misses any station whose latitude or longitude also
        # appears at an earlier station.  setdefault keeps the first occurrence.
        position = {}
        for index, pair in enumerate(zip(lanti_1, long_1)):
            position.setdefault(pair, index)

        load_rebuild = [0] * len(lanti_1)
        num = 0
        for i, lanti in enumerate(lanti_2):
            index = position.get((lanti, long_2[i]))
            if index is not None:
                load_rebuild[index] = load[i]
                num += 1
        return load_rebuild, num

    def sum_distribution(self):
        """Rebuild the load of every time part onto the base-station indexing.

        return
        ----------
            final_rebuild : list
                one entry per time part; entry k holds one load value per base
                station of (self.lanti1, self.long1), multiplied by self.adjust_rate

        refer
        ----------
            Other func :  self.load_distribution
        """
        final_rebuild = []
        for k, lanti_2 in enumerate(self.lanti_list):
            load_rebuild, _ = self.load_distribution(
                self.lanti1, self.long1, lanti_2, self.long_list[k], self.frequ_list[k])
            final_rebuild.append(list(np.array(load_rebuild) * self.adjust_rate))
        return final_rebuild

    def location_divide(self):
        """find the near k base stations of every base station and get their index

        uses
        ----------
            self.location_array : array
                structure : nx2
                content : first column ---> latitude , second column ---> longitude
            self.k_near : int
                the number of the near base stations we will find

        return
        ----------
            list_near : list
                store the k near base stations of each base
                structure : [[content1],[content2],.......,[contentn]]
                content : [] ---> the index of the k near base stations
            list_far : list
                store the other base stations of each base
                structure : [[content1],[content2],.......,[contentn]]
                content : [] ---> the index of the other base stations
            list_assign : list
                store the index of which station provide the flow to one base station
                structure : [[content1],[content2],.......,[contentn]]
                content : [] ---> the indexes of base stations that have this base
                          station among their near base stations

        refer
        ----------
            Other func :  self.array_calculate_distance
        """
        base_num = len(self.location_array)
        list_near = []
        list_far = []
        list_assign = [[] for _ in range(base_num)]

        for index in range(base_num):
            base_location = self.location_array[index]
            # Repeat this station's coordinates on every row so the distance
            # helper can compare it against all stations at once.
            base_all = np.ones((base_num, 1)) * base_location
            # NOTE(review): assumes array_calculate_distance returns one distance
            # per row (a 1-D array) -- TODO confirm.
            base_distance = self.array_calculate_distance(base_all, self.location_array)
            distance_sort_index = np.argsort(base_distance)
            distance_near = list(distance_sort_index[0:self.k_near])
            distance_far = list(distance_sort_index[self.k_near:base_num])
            list_near.append(distance_near)
            list_far.append(distance_far)
            # Iterate the stations actually found: when k_near > base_num the
            # slice is shorter than k_near.
            for near_index in distance_near:
                list_assign[near_index].append(index)
        return list_near, list_far, list_assign

    def server_placement(self):
        """we use the ilp model to test load balance function of each formulation

        Maximizes ``a`` subject to:
          * the integer server counts ``y`` sum to self.compute_num;
          * for every base station i, its shares ``u[i, j]`` over its near
            stations sum to ``a``;
          * for every time part and every station i, the shared load it receives
            (share times that base's load in the time part) is at most ``y[i]``.

        uses
        ----------
            list_near, list_far, list_assign : list
                from self.location_divide
            self.compute_num : int
                The number of the servers we want to assign
            self.sum_distribution() : list
                the per-time-part load of each base

        return
        ----------
            u_list : array
               store the proportion of the assignment in each base station
               structure : base_num x (number of near stations)
            y_list : array
               store the distribution of server in each base station
            final_optimal : 1/beta
               parameter to describe the load balance function (optimal value of ``a``)

        raise
        ----------
            RuntimeError
                if Gurobi does not finish with an optimal solution

        refer
        ----------
            Other func :  self.location_divide, self.sum_distribution
            Document : The introduction document of gurobi
        """
        list_near, _, list_assign = self.location_divide()
        final_rebuild = self.sum_distribution()
        base_num = len(list_near)
        near_num = len(list_near[0])

        model_placement = gr.Model('Lip_placement')
        # The original code used vtype 'S', which is Gurobi's semi-continuous type.
        # With lb=0 that has the same feasible set as a continuous variable, so
        # declare the proportions (and the objective variable) as continuous.
        u = model_placement.addVars(base_num, near_num, vtype=gr.GRB.CONTINUOUS, lb=0)
        y = model_placement.addVars(base_num, vtype=gr.GRB.INTEGER, lb=0)
        a = model_placement.addVar(vtype=gr.GRB.CONTINUOUS, lb=0)

        model_placement.setObjective(a, gr.GRB.MAXIMIZE)
        model_placement.addConstr(
            gr.quicksum(y[i] for i in range(base_num)) == self.compute_num
        )
        model_placement.addConstrs(
            gr.quicksum(u[i, j] for j in range(near_num)) == a
            for i in range(base_num)
        )
        # For every station i: (source base, position of i in that source's near
        # list).  This does not depend on the time part, so it is computed once.
        assign_pos = [
            [(source, list_near[source].index(i)) for source in list_assign[i]]
            for i in range(base_num)
        ]
        for load in final_rebuild:
            model_placement.addConstrs(
                gr.quicksum(u[source, pos] * load[source] for source, pos in assign_pos[i])
                <= y[i]
                for i in range(base_num)
            )
        model_placement.optimize()

        # Without this check the result variables below would be unbound and the
        # method would fail with an UnboundLocalError.
        if model_placement.Status != gr.GRB.OPTIMAL:
            raise RuntimeError(
                'server placement model not solved to optimality (Gurobi status %d)'
                % model_placement.Status)

        # Read the values from the variable handles rather than relying on the
        # order of model.getVars().
        u_list = np.array([[u[i, j].X for j in range(near_num)] for i in range(base_num)])
        y_list = np.array([y[i].X for i in range(base_num)])
        final_optimal = a.X
        return u_list, y_list, final_optimal